import numpy as np
from multiagent.environment_override import MultiAgentEnv
import multiagent.scenarios as scenarios
import gym_vecenv


def normalize_obs(obs, mean, std):
    """Standardize an observation with running statistics.

    Returns ``(obs - mean) / std`` when statistics are available; when
    ``mean`` is None the observation is passed through untouched.
    NOTE(review): assumes ``std`` has no zeros when ``mean`` is set —
    the caller is responsible for that invariant.
    """
    if mean is None:
        return obs
    return np.divide(obs - mean, std)


def make_env(env_id, seed, rank, num_agents, dist_threshold, arena_size, identity_size):
    """Build a thunk that constructs one seeded multi-agent environment.

    The returned callable takes a ``process_index`` (supplied by the
    vectorized-env launcher) and creates the scenario-backed env for that
    worker slot, seeded with ``seed + rank``.
    """
    def _thunk(process_index):
        environment = make_multiagent_env(
            env_id, num_agents, dist_threshold, arena_size,
            identity_size, process_index=process_index)
        # Per the original note, seed() is not implemented in the
        # override env, so this call is effectively a no-op.
        environment.seed(seed + rank)
        return environment

    return _thunk


def make_multiagent_env(env_id, num_agents, dist_threshold, arena_size, identity_size, process_index=0):
    """Load the scenario module named ``env_id`` and wrap it in a MultiAgentEnv.

    ``process_index`` distinguishes parallel workers (defaults to 0 for a
    single-process run).
    """
    scenario_cls = scenarios.load(env_id + ".py").Scenario
    scenario = scenario_cls(
        num_agents=num_agents,
        dist_threshold=dist_threshold,
        arena_size=arena_size,
        identity_size=identity_size,
        process_index=process_index,
    )

    # Some scenarios do not define an `info` callback; normalize the
    # attribute so downstream code can test it uniformly.
    if not hasattr(scenario, 'info'):
        scenario.info = None

    return MultiAgentEnv(scenario=scenario)


def make_parallel_envs(args):
    """Create ``args.num_processes`` subprocess environments.

    Each worker gets its own env thunk (seeded by rank), and the resulting
    vectorized env normalizes returns only — observations stay raw
    (``ob=False, ret=True``).
    """
    thunks = [
        make_env(args.env_name, args.seed, rank, args.num_agents,
                 args.dist_threshold, args.arena_size, args.identity_size)
        for rank in range(args.num_processes)
    ]
    vec_env = gym_vecenv.SubprocVecEnv(thunks)
    return gym_vecenv.MultiAgentVecNormalize(vec_env, ob=False, ret=True)


def init(module, weight_init, bias_init, gain=1):
    """Apply initializers to a layer's parameters in place.

    Runs ``weight_init`` (with ``gain``) on the weight tensor and
    ``bias_init`` on the bias tensor, then returns ``module`` so the call
    can be used fluently when building networks.
    """
    weight_tensor = module.weight.data
    bias_tensor = module.bias.data
    weight_init(weight_tensor, gain=gain)
    bias_init(bias_tensor)
    return module
